{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f1b13123",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([-0.76149553, -0.6481702 , -0.67783415], dtype=float32)"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import gym\n",
    "\n",
    "\n",
    "#Environment wrapper\n",
    "class MyWrapper(gym.Wrapper):\n",
    "    \"\"\"Pendulum-v1 wrapper with a 4-tuple ``step`` and a state-only ``reset``.\n",
    "\n",
    "    Episodes are additionally cut off after 200 calls to ``step``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        env = gym.make('Pendulum-v1', render_mode='rgb_array')\n",
    "        super().__init__(env)\n",
    "        self.env = env\n",
    "        #Number of steps taken since the last reset\n",
    "        self.step_n = 0\n",
    "\n",
    "    def reset(self):\n",
    "        \"\"\"Reset the environment and the step counter; return only the state.\"\"\"\n",
    "        observation, _info = self.env.reset()\n",
    "        self.step_n = 0\n",
    "        return observation\n",
    "\n",
    "    def step(self, action):\n",
    "        \"\"\"Advance one step, merging terminated/truncated into one done flag.\"\"\"\n",
    "        observation, reward, terminated, truncated, info = self.env.step(action)\n",
    "        self.step_n += 1\n",
    "        #Force the episode to end once 200 steps have been taken\n",
    "        done = True if self.step_n >= 200 else (terminated or truncated)\n",
    "        return observation, reward, done, info\n",
    "\n",
    "\n",
    "#Single wrapped environment instance used by the rest of the notebook\n",
    "env = MyWrapper()\n",
    "\n",
    "#Last expression of the cell, so the initial state is displayed\n",
    "env.reset()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2d74404a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "-1.6472302675247192"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import random\n",
    "from IPython import display\n",
    "import math\n",
    "\n",
    "\n",
    "class SAC:\n",
    "    class ModelAction(torch.nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.fc_state = torch.nn.Sequential(\n",
    "                torch.nn.Linear(3, 128),\n",
    "                torch.nn.ReLU(),\n",
    "            )\n",
    "            self.fc_mu = torch.nn.Linear(128, 1)\n",
    "            self.fc_std = torch.nn.Sequential(\n",
    "                torch.nn.Linear(128, 1),\n",
    "                torch.nn.Softplus(),\n",
    "            )\n",
    "\n",
    "        def forward(self, state):\n",
    "            #[b, 3] -> [b, 128]\n",
    "            state = self.fc_state(state)\n",
    "\n",
    "            #[b, 128] -> [b, 1]\n",
    "            mu = self.fc_mu(state)\n",
    "\n",
    "            #[b, 128] -> [b, 1]\n",
    "            std = self.fc_std(state)\n",
    "\n",
    "            #根据mu和std定义b个正态分布\n",
    "            dist = torch.distributions.Normal(mu, std)\n",
    "\n",
    "            #采样b个样本\n",
    "            #这里用的是rsample,表示重采样,其实就是先从一个标准正态分布中采样,然后乘以标准差,加上均值\n",
    "            sample = dist.rsample()\n",
    "\n",
    "            #样本压缩到-1,1之间,求动作\n",
    "            action = torch.tanh(sample)\n",
    "\n",
    "            #求概率对数\n",
    "            log_prob = dist.log_prob(sample)\n",
    "\n",
    "            #这个式子看不懂,但参照上下文理解,这个值应该描述的是动作的熵\n",
    "            entropy = log_prob - (1 - action.tanh()**2 + 1e-7).log()\n",
    "            entropy = -entropy\n",
    "\n",
    "            return action * 2, entropy\n",
    "\n",
    "    class ModelValue(torch.nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "            self.sequential = torch.nn.Sequential(\n",
    "                torch.nn.Linear(4, 128),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(128, 128),\n",
    "                torch.nn.ReLU(),\n",
    "                torch.nn.Linear(128, 1),\n",
    "            )\n",
    "\n",
    "        def forward(self, state, action):\n",
    "            #[b, 3+1] -> [b, 4]\n",
    "            state = torch.cat([state, action], dim=1)\n",
    "\n",
    "            #[b, 4] -> [b, 1]\n",
    "            return self.sequential(state)\n",
    "\n",
    "    def __init__(self):\n",
    "        self.model_action = self.ModelAction()\n",
    "\n",
    "        self.model_value1 = self.ModelValue()\n",
    "        self.model_value2 = self.ModelValue()\n",
    "\n",
    "        self.model_value_next1 = self.ModelValue()\n",
    "        self.model_value_next2 = self.ModelValue()\n",
    "\n",
    "        self.model_value_next1.load_state_dict(self.model_value1.state_dict())\n",
    "        self.model_value_next2.load_state_dict(self.model_value2.state_dict())\n",
    "\n",
    "        #这也是一个可学习的参数\n",
    "        self.alpha = torch.tensor(math.log(0.01))\n",
    "        self.alpha.requires_grad = True\n",
    "\n",
    "        self.optimizer_action = torch.optim.Adam(\n",
    "            self.model_action.parameters(), lr=3e-4)\n",
    "        self.optimizer_value1 = torch.optim.Adam(\n",
    "            self.model_value1.parameters(), lr=3e-3)\n",
    "        self.optimizer_value2 = torch.optim.Adam(\n",
    "            self.model_value2.parameters(), lr=3e-3)\n",
    "\n",
    "        #alpha也是要更新的参数,所以这里要定义优化器\n",
    "        self.optimizer_alpha = torch.optim.Adam([self.alpha], lr=3e-4)\n",
    "\n",
    "        self.loss_fn = torch.nn.MSELoss()\n",
    "\n",
    "    def get_action(self, state):\n",
    "        state = torch.FloatTensor(state).reshape(1, 3)\n",
    "        action, _ = self.model_action(state)\n",
    "        return action.item()\n",
    "\n",
    "    def _soft_update(self, model, model_next):\n",
    "        for param, param_next in zip(model.parameters(),\n",
    "                                     model_next.parameters()):\n",
    "            #以一个小的比例更新\n",
    "            value = param_next.data * 0.995 + param.data * 0.005\n",
    "            param_next.data.copy_(value)\n",
    "\n",
    "    def _get_target(self, reward, next_state, over):\n",
    "        #首先使用model_action计算动作和动作的熵\n",
    "        #[b, 4] -> [b, 1],[b, 1]\n",
    "        action, entropy = self.model_action(next_state)\n",
    "\n",
    "        #评估next_state的价值\n",
    "        #[b, 4],[b, 1] -> [b, 1]\n",
    "        target1 = self.model_value_next1(next_state, action)\n",
    "        target2 = self.model_value_next2(next_state, action)\n",
    "\n",
    "        #取价值小的,这是出于稳定性考虑\n",
    "        #[b, 1]\n",
    "        target = torch.min(target1, target2)\n",
    "\n",
    "        #exp和log互为反操作,这里是把alpha还原了\n",
    "        #这里的操作是在target上加上了动作的熵,alpha作为权重系数\n",
    "        #[b, 1] - [b, 1] -> [b, 1]\n",
    "        target += self.alpha.exp() * entropy\n",
    "\n",
    "        #[b, 1]\n",
    "        target *= 0.99\n",
    "        target *= (1 - over)\n",
    "        target += reward\n",
    "\n",
    "        return target\n",
    "\n",
    "    def _get_loss_action(self, state):\n",
    "        #计算action和熵\n",
    "        #[b, 3] -> [b, 1],[b, 1]\n",
    "        action, entropy = self.model_action(state)\n",
    "\n",
    "        #使用两个value网络评估action的价值\n",
    "        #[b, 3],[b, 1] -> [b, 1]\n",
    "        value1 = self.model_value1(state, action)\n",
    "        value2 = self.model_value2(state, action)\n",
    "\n",
    "        #取价值小的,出于稳定性考虑\n",
    "        #[b, 1]\n",
    "        value = torch.min(value1, value2)\n",
    "\n",
    "        #alpha还原后乘以熵,这个值期望的是越大越好,但是这里是计算loss,所以符号取反\n",
    "        #[1] - [b, 1] -> [b, 1]\n",
    "        loss_action = -self.alpha.exp() * entropy\n",
    "\n",
    "        #减去value,所以value越大越好,这样loss就会越小\n",
    "        loss_action -= value\n",
    "\n",
    "        return loss_action.mean(), entropy\n",
    "\n",
    "    def _get_loss_value(self, model_value, target, state, action, next_state):\n",
    "        #计算value\n",
    "        value = model_value(state, action)\n",
    "\n",
    "        #计算loss,value的目标是要贴近target\n",
    "        loss_value = self.loss_fn(value, target)\n",
    "        return loss_value\n",
    "\n",
    "    def train(self, state, action, reward, next_state, over):\n",
    "        #对reward偏移,为了便于训练\n",
    "        reward = (reward + 8) / 8\n",
    "\n",
    "        #计算target,这个target里已经考虑了动作的熵\n",
    "        #[b, 1]\n",
    "        target = self._get_target(reward, next_state, over)\n",
    "        target = target.detach()\n",
    "\n",
    "        #计算两个value loss\n",
    "        loss_value1 = self._get_loss_value(self.model_value1, target, state,\n",
    "                                           action, next_state)\n",
    "        loss_value2 = self._get_loss_value(self.model_value2, target, state,\n",
    "                                           action, next_state)\n",
    "\n",
    "        #更新参数\n",
    "        self.optimizer_value1.zero_grad()\n",
    "        loss_value1.backward()\n",
    "        self.optimizer_value1.step()\n",
    "\n",
    "        self.optimizer_value2.zero_grad()\n",
    "        loss_value2.backward()\n",
    "        self.optimizer_value2.step()\n",
    "\n",
    "        #使用model_value计算model_action的loss\n",
    "        loss_action, entropy = self._get_loss_action(state)\n",
    "        self.optimizer_action.zero_grad()\n",
    "        loss_action.backward()\n",
    "        self.optimizer_action.step()\n",
    "\n",
    "        #熵乘以alpha就是alpha的loss\n",
    "        #[b, 1] -> [1]\n",
    "        loss_alpha = (entropy + 1).detach() * self.alpha.exp()\n",
    "        loss_alpha = loss_alpha.mean()\n",
    "\n",
    "        #更新alpha值\n",
    "        self.optimizer_alpha.zero_grad()\n",
    "        loss_alpha.backward()\n",
    "        self.optimizer_alpha.step()\n",
    "\n",
    "        #增量更新next模型\n",
    "        self._soft_update(self.model_value1, self.model_value_next1)\n",
    "        self._soft_update(self.model_value2, self.model_value_next2)\n",
    "\n",
    "\n",
    "#Agent instance shared by the following cells\n",
    "sac = SAC()\n",
    "\n",
    "#Smoke test: one update on a random batch of 5 transitions\n",
    "sac.train(\n",
    "    torch.randn(5, 3),  #state\n",
    "    torch.randn(5, 1),  #action\n",
    "    torch.randn(5, 1),  #reward\n",
    "    torch.randn(5, 3),  #next_state\n",
    "    torch.zeros(5, 1).long(),  #over\n",
    ")\n",
    "\n",
    "#Sample one action for an arbitrary state; the cell displays the result\n",
    "sac.get_action([1, 2, 3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2aaf9c43",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(200,\n",
       " ([0.7253544330596924, 0.6883755326271057, -0.6794425845146179],\n",
       "  -0.8116707801818848,\n",
       "  -0.6232792613293986,\n",
       "  [0.7350868582725525, 0.6779729723930359, -0.28491151332855225],\n",
       "  False),\n",
       " (tensor([[-0.8223,  0.5690,  8.0000],\n",
       "          [ 0.9318, -0.3629,  4.7879]]),\n",
       "  tensor([[ 2.0000],\n",
       "          [-1.4103]]),\n",
       "  tensor([[-12.8367],\n",
       "          [ -2.4323]]),\n",
       "  tensor([[-0.9790,  0.2039,  8.0000],\n",
       "          [ 0.9878, -0.1555,  4.3042]]),\n",
       "  tensor([[0],\n",
       "          [0]])))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "class Pool:\n",
    "    def __init__(self, limit):\n",
    "        #样本池\n",
    "        self.datas = []\n",
    "        self.limit = limit\n",
    "\n",
    "    def add(self, state, action, reward, next_state, over):\n",
    "        if isinstance(state, np.ndarray) or isinstance(state, torch.Tensor):\n",
    "            state = state.reshape(3).tolist()\n",
    "\n",
    "        action = float(action)\n",
    "\n",
    "        reward = float(reward)\n",
    "\n",
    "        if isinstance(next_state, np.ndarray) or isinstance(\n",
    "                next_state, torch.Tensor):\n",
    "            next_state = next_state.reshape(3).tolist()\n",
    "\n",
    "        over = bool(over)\n",
    "\n",
    "        self.datas.append((state, action, reward, next_state, over))\n",
    "        #数据上限,超出时从最古老的开始删除\n",
    "        while len(self.datas) > self.limit:\n",
    "            self.datas.pop(0)\n",
    "\n",
    "    #获取一批数据样本\n",
    "    def get_sample(self, size=None):\n",
    "        if size is None:\n",
    "            size = len(self)\n",
    "\n",
    "        size = min(size, len(self))\n",
    "\n",
    "        #从样本池中采样\n",
    "        samples = random.sample(self.datas, size)\n",
    "\n",
    "        #[b, 3]\n",
    "        state = torch.FloatTensor([i[0] for i in samples]).reshape(-1, 3)\n",
    "        #[b, 1]\n",
    "        action = torch.FloatTensor([i[1] for i in samples]).reshape(-1, 1)\n",
    "        #[b, 1]\n",
    "        reward = torch.FloatTensor([i[2] for i in samples]).reshape(-1, 1)\n",
    "        #[b, 3]\n",
    "        next_state = torch.FloatTensor([i[3] for i in samples]).reshape(-1, 3)\n",
    "        #[b, 1]\n",
    "        over = torch.LongTensor([i[4] for i in samples]).reshape(-1, 1)\n",
    "\n",
    "        return state, action, reward, next_state, over\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.datas)\n",
    "\n",
    "\n",
    "#env_pool receives transitions from the real environment,\n",
    "#model_pool receives transitions generated by MBPO.rollout\n",
    "env_pool = Pool(10000)\n",
    "model_pool = Pool(1000)\n",
    "\n",
    "\n",
    "#Seed env_pool with the data of one full episode\n",
    "def _():\n",
    "    \"\"\"Play one episode with the current policy and store every transition.\"\"\"\n",
    "    state = env.reset()\n",
    "\n",
    "    while True:\n",
    "        #Pick an action for the current state\n",
    "        action = sac.get_action(state)\n",
    "\n",
    "        #Apply it and observe the outcome\n",
    "        next_state, reward, over, _info = env.step([action])\n",
    "\n",
    "        #Record the transition\n",
    "        env_pool.add(state, action, reward, next_state, over)\n",
    "\n",
    "        #Stop once the episode has ended\n",
    "        if over:\n",
    "            break\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "\n",
    "#Collect the initial episode\n",
    "_()\n",
    "\n",
    "#Inspect the pool: size, first stored transition, and a batch of 2\n",
    "len(env_pool), env_pool.datas[0], env_pool.get_sample(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e5756f12",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([5, 64, 4]), torch.Size([5, 64, 4]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义主模型\n",
    "class Model(torch.nn.Module):\n",
    "    #swish激活函数\n",
    "    class Swish(torch.nn.Module):\n",
    "        def __init__(self):\n",
    "            super().__init__()\n",
    "\n",
    "        def forward(self, x):\n",
    "            return x * torch.sigmoid(x)\n",
    "\n",
    "    #定义一个工具层\n",
    "    class FCLayer(torch.nn.Module):\n",
    "        def __init__(self, in_size, out_size):\n",
    "            super().__init__()\n",
    "            self.in_size = in_size\n",
    "\n",
    "            #初始化参数\n",
    "            std = in_size**0.5\n",
    "            std *= 2\n",
    "            std = 1 / std\n",
    "\n",
    "            weight = torch.empty(5, in_size, out_size)\n",
    "            torch.nn.init.normal_(weight, mean=0.0, std=std)\n",
    "\n",
    "            #[5, in, out]\n",
    "            self.weight = torch.nn.Parameter(weight)\n",
    "\n",
    "            #[5, 1, out]\n",
    "            self.bias = torch.nn.Parameter(torch.zeros(5, 1, out_size))\n",
    "\n",
    "        def forward(self, x):\n",
    "            #x -> [5, b, in]\n",
    "\n",
    "            #[5, b, in] * [5, in, out] -> [5, b, out]\n",
    "            x = torch.bmm(x, self.weight)\n",
    "\n",
    "            #[5, b, out] + [5, 1, out] -> [5, b, out]\n",
    "            x = x + self.bias\n",
    "\n",
    "            return x\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "\n",
    "        self.sequential = torch.nn.Sequential(\n",
    "            self.FCLayer(4, 200),\n",
    "            self.Swish(),\n",
    "            self.FCLayer(200, 200),\n",
    "            self.Swish(),\n",
    "            self.FCLayer(200, 200),\n",
    "            self.Swish(),\n",
    "            self.FCLayer(200, 200),\n",
    "            self.Swish(),\n",
    "            self.FCLayer(200, 8),\n",
    "            torch.nn.Identity(),\n",
    "        )\n",
    "\n",
    "        self.softplus = torch.nn.Softplus()\n",
    "        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
    "\n",
    "    def forward(self, x):\n",
    "        #x -> [5, b, 4]\n",
    "\n",
    "        #[5, b, 4] -> [5, b, 8]\n",
    "        x = self.sequential(x)\n",
    "\n",
    "        #[5, b, 8] -> [5, b, 4]\n",
    "        mean = x[..., :4]\n",
    "\n",
    "        #[5, b, 8] -> [5, b, 4]\n",
    "        logvar = x[..., 4:]\n",
    "\n",
    "        #[1, 1, 4] - [5, b, 4] -> [5, b, 4]\n",
    "        logvar = 0.5 - logvar\n",
    "\n",
    "        #[1, 1, 4] - [5, b, 4] -> [5, b, 4]\n",
    "        logvar = 0.5 - self.softplus(logvar)\n",
    "\n",
    "        #[5, b, 4] - [1, 1, 4] -> [5, b, 4]\n",
    "        logvar = logvar + 10\n",
    "\n",
    "        #[5, b, 4] + [1, 1, 4] -> [5, b, 4]\n",
    "        logvar = self.softplus(logvar) - 10\n",
    "\n",
    "        #[5, b, 4],[5, b, 4]\n",
    "        return mean, logvar\n",
    "\n",
    "    def train(self):\n",
    "        state, action, reward, next_state, _ = env_pool.get_sample()\n",
    "\n",
    "        #input -> [b, 4]\n",
    "        #label -> [b, 4]\n",
    "        input = torch.cat([state, action], dim=1)\n",
    "        label = torch.cat([reward, next_state - state], dim=1)\n",
    "\n",
    "        #反复训练N次\n",
    "        for _ in range(len(input) // 64 * 20):\n",
    "            #从全量数据中抽样64个,反复抽5遍,形成5份数据\n",
    "            #[5, 64]\n",
    "            select = [torch.randperm(len(input))[:64] for _ in range(5)]\n",
    "            select = torch.stack(select)\n",
    "            #[5, b, 4],[5, b, 4]\n",
    "            input_select = input[select]\n",
    "            label_select = label[select]\n",
    "            del select\n",
    "\n",
    "            #模型计算\n",
    "            #[5, b, 4] -> [5, b, 4],[5, b, 4]\n",
    "            mean, logvar = model(input_select)\n",
    "\n",
    "            #计算loss\n",
    "            #[b, 4] - [b, 4] * [b, 4] -> [b, 4]\n",
    "            mse_loss = (mean - label_select)**2 * (-logvar).exp()\n",
    "\n",
    "            #[b, 4] -> [b] -> scala\n",
    "            mse_loss = mse_loss.mean(dim=1).mean()\n",
    "\n",
    "            #[b, 4] -> [b] -> scala\n",
    "            var_loss = logvar.mean(dim=1).mean()\n",
    "\n",
    "            loss = mse_loss + var_loss\n",
    "\n",
    "            self.optimizer.zero_grad()\n",
    "            loss.backward()\n",
    "            self.optimizer.step()\n",
    "\n",
    "\n",
    "#Dynamics model instance used by MBPO and the training loop\n",
    "model = Model()\n",
    "\n",
    "#model.train()\n",
    "\n",
    "#Shape check: a random [5, 64, 4] input yields mean and logvar of [5, 64, 4]\n",
    "a, b = model(torch.randn(5, 64, 4))\n",
    "a.shape, b.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "695d9616",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 1]) torch.Size([1, 3])\n"
     ]
    }
   ],
   "source": [
    "class MBPO():\n",
    "    def _fake_step(self, state, action):\n",
    "        state = torch.FloatTensor(state).reshape(-1, 3)\n",
    "        action = torch.FloatTensor([action]).reshape(-1, 1)\n",
    "        #state -> [b, 3]\n",
    "        #action -> [b, 1]\n",
    "\n",
    "        #[b, 4]\n",
    "        input = torch.cat([state, action], dim=1)\n",
    "\n",
    "        #重复5遍\n",
    "        #[b, 4] -> [1, b, 4] -> [5, b, 4]\n",
    "        input = input.unsqueeze(dim=0).repeat([5, 1, 1])\n",
    "\n",
    "        #模型计算\n",
    "        #[5, b, 4] -> [5, b, 4],[5, b, 4]\n",
    "        with torch.no_grad():\n",
    "            mean, std = model(input)\n",
    "        std = std.exp().sqrt()\n",
    "        del input\n",
    "\n",
    "        #means的后3列加上环境数据\n",
    "        mean[:, :, 1:] += state\n",
    "\n",
    "        #重采样\n",
    "        #[5, b ,4]\n",
    "        sample = torch.distributions.Normal(0, 1).sample(mean.shape)\n",
    "        sample = mean + sample * std\n",
    "\n",
    "        #0-4的值域采样b个元素\n",
    "        #[4, 4, 2, 4, 3, 4, 1, 3, 3, 0, 2,...\n",
    "        select = [random.choice(range(5)) for _ in range(mean.shape[1])]\n",
    "\n",
    "        #重采样结果,的结果,第0个维度,0-4随机选择,第二个维度,0-b顺序选择\n",
    "        #[5, b ,4] -> [b, 4]\n",
    "        sample = sample[select, range(mean.shape[1])]\n",
    "\n",
    "        #切分一下,就成了rewards, next_state\n",
    "        reward, next_state = sample[:, :1], sample[:, 1:]\n",
    "\n",
    "        return reward, next_state\n",
    "\n",
    "    def rollout(self):\n",
    "        states, _, _, _, _ = env_pool.get_sample(1000)\n",
    "        for state in states:\n",
    "            action = sac.get_action(state)\n",
    "            reward, next_state = self._fake_step(state, action)\n",
    "\n",
    "            model_pool.add(state, action, reward, next_state, False)\n",
    "            state = next_state\n",
    "\n",
    "\n",
    "#Planner instance used by the training loop\n",
    "mbpo = MBPO()\n",
    "#Shape check of one simulated step: prints the reward and next_state shapes\n",
    "a, b = mbpo._fake_step([1, 2, 3], 1)\n",
    "print(a.shape, b.shape)\n",
    "\n",
    "#mbpo.rollout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "10b202af",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 400 1000 -1741.1025408844582\n",
      "1 600 1000 -1482.9190299624386\n",
      "2 800 1000 -806.0623477347674\n",
      "3 1000 1000 -497.80163937120295\n",
      "4 1200 1000 -359.931277366291\n",
      "5 1400 1000 -115.15998752982819\n",
      "6 1600 1000 -1608.3966523265883\n",
      "7 1800 1000 -2.797669637970453\n",
      "8 2000 1000 -1574.017557479941\n",
      "9 2200 1000 -231.0927805243903\n",
      "10 2400 1000 -1080.970796537963\n",
      "11 2600 1000 -1508.316389930871\n",
      "12 2800 1000 -1080.8601296082209\n",
      "13 3000 1000 -126.31842834927\n",
      "14 3200 1000 -126.55971567086468\n",
      "15 3400 1000 -131.18131552262219\n",
      "16 3600 1000 -124.78880990781981\n",
      "17 3800 1000 -356.0421356280981\n",
      "18 4000 1000 -230.75949344038028\n",
      "19 4200 1000 -490.9596835028273\n"
     ]
    }
   ],
   "source": [
    "#Main loop: 20 episodes of real interaction mixed with model-generated data\n",
    "for i in range(20):\n",
    "    state = env.reset()\n",
    "    reward_sum = 0\n",
    "    over = False\n",
    "    step = 0\n",
    "\n",
    "    while not over:\n",
    "        #Every 50 steps refit the dynamics model and add synthetic data\n",
    "        if step % 50 == 0:\n",
    "            model.train()\n",
    "            mbpo.rollout()\n",
    "        step += 1\n",
    "\n",
    "        #Act in the real environment with the current policy\n",
    "        action = sac.get_action(state)\n",
    "        next_state, reward, over, _ = env.step([action])\n",
    "\n",
    "        #Track the episode return and store the real transition\n",
    "        reward_sum += reward\n",
    "        env_pool.add(state, action, reward, next_state, over)\n",
    "\n",
    "        state = next_state\n",
    "\n",
    "        #10 SAC updates per step, each on 32 real plus 32 synthetic samples\n",
    "        for _ in range(10):\n",
    "            sample_env = env_pool.get_sample(32)\n",
    "            sample_model = model_pool.get_sample(32)\n",
    "\n",
    "            #Concatenate the two batches field by field along the batch axis\n",
    "            sample = [\n",
    "                torch.cat([i1, i2], dim=0)\n",
    "                for i1, i2 in zip(sample_env, sample_model)\n",
    "            ]\n",
    "\n",
    "            sac.train(*sample)\n",
    "\n",
    "    print(i, len(env_pool), len(model_pool), reward_sum)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
